import tensorflow as tf
from env1 import Env
import numpy as np
import datetime
import pandas as pd
import os
from Network_MBAC import Network
from memory import ReplayMemory

# ---------------------------------------------------------------------------
# Network hyper-parameters (all forwarded to Network_MBAC.Network)
# ---------------------------------------------------------------------------
units = 256          # `units` argument of Network (presumably layer width — confirm in Network_MBAC)
LR_A = 3e-4          # learning rate for the actor
LR_C1 = 3e-4         # learning rate for critic 1
LR_C2 = 3e-4         # learning rate for critic 2
LR_L = 3e-4          # `LR_L` argument of Network — role defined in Network_MBAC, TODO confirm
LR_R = 3e-4          # `LR_R` argument of Network — role defined in Network_MBAC, TODO confirm
TAU_a = 1            # target-update coefficient, actor (1 presumably means a hard copy — confirm)
TAU_c1 = 0.005       # target-update coefficient, critic 1
TAU_c2 = 0.005       # target-update coefficient, critic 2
TAU_l = 1            # `TAU_l` argument of Network — role defined in Network_MBAC, TODO confirm

# ---------------------------------------------------------------------------
# Run-mode switches (read by the __main__ section)
# ---------------------------------------------------------------------------
Reset = False        # True: only (re)initialise the seed CSV and plot, no training
On_train = True      # False: skip training and go straight to evaluation
Conti_train = True   # True: resume from the seed index stored in the seed CSV

# ---------------------------------------------------------------------------
# Experiment layout
# ---------------------------------------------------------------------------
Seed_num = 1                  # number of seeds (independent runs) to execute
arm_num = 1
label = 'Arm_' + str(arm_num)  # sub-directory name under ./data/
Limit = np.inf                # `lim` passed to Env.reset_mul on the first reset
interval = 500
TEST = 10                     # number of test episodes per evaluation (evaluation cadence set inside Network — confirm)
Runout = 10000

# ---------------------------------------------------------------------------
# Training schedule
# ---------------------------------------------------------------------------
observe = 1500000             # forwarded to Network as `observe`
observe1 = 1500000            # train() stops stepping the real env once train_counter reaches this
Update_actor_freq = 1
sample_mini = 1               # number of workers; also selects the algorithm tag in Worker


class Worker(object):
    """An environment-interacting worker attached to a shared global network.

    The class attributes configure the replay buffer and the training budget.
    Only ``'Worker_0'`` owns a ``ReplayMemory``. ``train()`` pushes every
    worker's transitions into ``workers[0].memory`` and calls
    ``experience_replay`` only on ``workers[0]``.
    """

    Memory_size = 1500000          # replay-buffer capacity passed to ReplayMemory
    batch_mini = 256               # minibatch size drawn per replay update
    train_max = 1500000            # training stops once globalAC.train_counter reaches this
    train_max_mini = sample_mini   # replay updates per experience_replay() call
    multi_step = 1                 # not referenced elsewhere in this script

    # A `global` statement in a class body makes the assignments below bind
    # the *module-level* name ``Algo``. The __main__ section reads it for the
    # Network tag and the CSV/directory paths. It must run at class
    # definition time, so keep it here.
    global Algo
    if sample_mini == 1:
        Algo = 'MBAC1'
    elif Memory_size == sample_mini:
        Algo = 'MMBAC'
    else:
        Algo = 'MMBAC_v1'

    def __init__(self, name, globalAC):
        """Create the worker's environment and local network.

        Args:
            name: worker name such as ``'Worker_0'``. Only ``'Worker_0'``
                allocates the shared replay memory.
            globalAC: the global ``Network`` that replay updates are applied to.

        Relies on the module-level ``Seed`` and ``SESS`` that the
        ``__main__`` section assigns before constructing workers.
        """
        self.s = None  # current observation batch; set by train()
        if name == 'Worker_0':
            self.memory = ReplayMemory(self.Memory_size, Seed)
        self.name = name
        self.env = Env(seed=Seed)
        self.globalAC = globalAC
        self.AC = Network(self.env, TAU_a, 0, 0, TAU_l, LR_A, 0, 0, LR_L, LR_R, 0, arm_num, Seed, Seed_num, Algo, units, interval, TEST, Runout, name, SESS, Update_actor_freq, TAU_c1, TAU_c2, LR_C1, LR_C2, observe)

    def experience_replay(self):
        """Run up to ``train_max_mini`` replay updates on the global network.

        Stops early once ``globalAC.train_counter`` reaches ``train_max``.
        The check runs *before* sampling, so no minibatch is drawn and then
        discarded after the training budget is used up.
        """
        for _ in range(self.train_max_mini):
            if self.globalAC.train_counter >= self.train_max:
                break
            bs, ba, bs_, br, bd = self.memory.sample(self.batch_mini)
            self.globalAC.var_run(bs, ba, bs_, br, bd)

def train():
    """Collect environment transitions and run replay updates until the budget is spent.

    Reads module-level globals set by the ``__main__`` section: ``workers``,
    ``GLOBAL_AC``, ``Limit`` and ``observe1``. Each worker's transition
    (element 0 of the batched ``step_mul`` output) is pushed into
    ``workers[0].memory``. After every round of environment steps, once that
    memory holds more than ``batch_mini`` samples, ``workers[0]`` performs
    replay updates. Loops while ``GLOBAL_AC.train_counter < train_max``.
    """
    lead = workers[0]  # owns the shared replay memory and the training budget
    for worker in workers:
        worker.s = worker.env.reset_mul(lim=Limit)
    while GLOBAL_AC.train_counter < lead.train_max:
        for worker in workers:
            a = GLOBAL_AC.choose_action_exp(worker.s)
            if GLOBAL_AC.train_counter >= observe1:
                # Model-based rollout branch, currently disabled: once
                # train_counter reaches observe1, no environment step is taken.
                # s_next, r_on, a_alt = GLOBAL_AC.step(worker.s, a)
                # done = worker.env.done(s_next[0])
                # workers[0].memory.push(worker.s[0], a_alt[0], s_next[0], r_on[0], done[0])
                # worker.s = s_next
                pass
            else:
                s_, r, _, done = worker.env.step_mul(a)
                lead.memory.push(worker.s[0], a[0], s_[0], r[0], done[0])
                worker.s = s_
                # Test the same element that was stored above. A bare
                # `if done:` is always true for a non-empty list and raises
                # for a multi-element numpy array.
                if done[0]:
                    # NOTE(review): the initial reset passes lim=Limit, but this
                    # one uses reset_mul's default `lim` — confirm that is intended.
                    worker.s = worker.env.reset_mul()
        if len(lead.memory) > lead.batch_mini:
            lead.experience_replay()

if __name__ == "__main__":
    data_dir = './data/' + label
    # CSV holding the next seed index to run. It is written after each seed
    # finishes and read back when resuming (Conti_train) or evaluating.
    seed_csv = data_dir + '/seed_' + Algo + '.csv'
    os.makedirs(data_dir, exist_ok=True)
    start_time = datetime.datetime.now()
    Seed = 0

    if Reset:
        # Initialise the seed record.
        pd.DataFrame(np.array([Seed])).to_csv(seed_csv, index=False, header=True)
        tf.reset_default_graph()
        tf.set_random_seed(Seed)
        SESS = tf.InteractiveSession()

        # Build the global network and the workers.
        # NOTE(review): unlike the training/eval branches below, Env is created
        # here without seed=Seed — confirm that is intended.
        env = Env()
        GLOBAL_AC = Network(env, TAU_a, 0, 0, TAU_l, LR_A, 0, 0, LR_L, LR_R, 0, arm_num, Seed, Seed_num, Algo, units, interval, TEST, Runout, 'global', SESS, Update_actor_freq, TAU_c1, TAU_c2, LR_C1, LR_C2, observe)
        workers = []
        for i in range(sample_mini):
            i_name = 'Worker_%i' % i  # worker name
            workers.append(Worker(i_name, GLOBAL_AC))
            os.makedirs(data_dir + '/workers_' + Algo + '/' + i_name, exist_ok=True)

        SESS.run(tf.global_variables_initializer())
        for worker in workers:
            worker.AC.plot_cost()
        GLOBAL_AC.plot_cost()
        SESS.close()
    else:
        if Conti_train or not On_train:
            # Resume (or evaluate) from the seed index recorded in the CSV.
            Seed = int(np.array(pd.read_csv(seed_csv)).reshape(-1)[0])
            # NOTE(review): clearing Conti_train here means the load_par()
            # branch inside the loop below can never execute — confirm intent.
            Conti_train = False
        else:
            # Fresh training run (here Conti_train is False): record the start seed.
            pd.DataFrame(np.array([Seed])).to_csv(seed_csv, index=False, header=True)

        while Seed < Seed_num:
            np.random.seed(Seed)
            tf.reset_default_graph()
            tf.set_random_seed(Seed)
            SESS = tf.InteractiveSession()

            env = Env(seed=Seed)
            GLOBAL_AC = Network(env, TAU_a, 0, 0, TAU_l, LR_A, 0, 0, LR_L, LR_R, 0, arm_num, Seed, Seed_num, Algo, units, interval, TEST, Runout, 'global', SESS, Update_actor_freq, TAU_c1, TAU_c2, LR_C1, LR_C2, observe)
            workers = []
            for i in range(sample_mini):
                i_name = 'Worker_%i' % i  # worker name
                workers.append(Worker(i_name, GLOBAL_AC))
                os.makedirs(data_dir + '/workers_' + Algo + '/' + i_name, exist_ok=True)

            SESS.run(tf.global_variables_initializer())
            if On_train:
                if Conti_train:
                    for worker in workers:
                        worker.AC.load_par()
                print(GLOBAL_AC.scope, '---第' + ' ' + str(Seed) + ' ' + '轮开始---')
                train()
                for worker in workers:
                    worker.AC.plot_cost()
                print(GLOBAL_AC.scope, '---第' + ' ' + str(Seed) + ' ' + '轮结束---')
            GLOBAL_AC.plot_cost()
            SESS.close()

            # Only the first seed of a resumed sweep may reload parameters.
            Conti_train = False
            Seed += 1
            pd.DataFrame(np.array([Seed])).to_csv(seed_csv, index=False, header=True)

            end_time = datetime.datetime.now()
            # total_seconds() includes whole days; timedelta.seconds alone
            # wraps back to 0 every 24 h and would under-report long runs.
            elapsed = int((end_time - start_time).total_seconds())
            print('\n----------用时', elapsed // 60, '分', elapsed % 60, '秒-----------')

        if Seed >= Seed_num:
            # All seeds are done: rebuild the global network and evaluate it.
            tf.reset_default_graph()
            SESS = tf.InteractiveSession()
            env = Env(seed=Seed)
            GLOBAL_AC = Network(env, TAU_a, 0, 0, TAU_l, LR_A, 0, 0, LR_L, LR_R, 0, arm_num, Seed, Seed_num, Algo, units, interval, TEST, Runout, 'global', SESS, Update_actor_freq, TAU_c1, TAU_c2, LR_C1, LR_C2, observe)
            SESS.run(tf.global_variables_initializer())
            GLOBAL_AC.eval()
            SESS.close()